{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Visualizing the effect of $L_1/L_2$ regularization\n",
    "\n",
    "We use a toy example with two weights $(w_0, w_1)$ to illustrate the effect $L_1$ and $L_2$ regularization has on the solution to a loss minimization problem.\n",
    "\n",
    "## Table of Contents\n",
    "\n",
    "1. [Draw the data loss and the L1/L2L1/L2 regularization curves](#Draw-the-data-loss-and-the-%24L_1%2FL_2%24-regularization-curves)\n",
    "2. [Plot the training progress](#Plot-the-training-progress)\n",
    "3. [L1L1 -norm regularization leads to \"near-sparsity\"](#%24L_1%24-norm-regularization-leads-to-%22near-sparsity%22)\n",
    "4. [References](#References)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import axes3d\n",
    "import matplotlib.animation as animation\n",
    "import matplotlib.patches as mpatches\n",
    "\n",
    "from torch.autograd import Variable\n",
    "import torch\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import set_matplotlib_formats\n",
    "set_matplotlib_formats('pdf', 'png')\n",
    "plt.rcParams['savefig.dpi'] = 75\n",
    "\n",
    "plt.rcParams['figure.autolayout'] = False\n",
    "plt.rcParams['figure.figsize'] = 10, 6\n",
    "plt.rcParams['axes.labelsize'] = 18\n",
    "plt.rcParams['axes.titlesize'] = 20\n",
    "plt.rcParams['font.size'] = 16\n",
    "plt.rcParams['lines.linewidth'] = 2.0\n",
    "plt.rcParams['lines.markersize'] = 8\n",
    "plt.rcParams['legend.fontsize'] = 14\n",
    "\n",
    "# plt.rcParams['text.usetex'] = True\n",
    "plt.rcParams['font.family'] = \"Sans Serif\"\n",
    "plt.rcParams['font.serif'] = \"cm\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Draw the data loss and the $L_1/L_2$ regularization curves\n",
    "\n",
    "We choose just a very simple convex loss function for illustration (in blue), which has its minimum at W=(3,2):\n",
    "\n",
    "<!-- Equation labels as ordinary links -->\n",
    "<div id=\"eq:loss\"></div>\n",
    "\n",
    "$$\n",
    "\\begin{equation}\n",
    "loss(W) = 0.5(w_0-3)^2 + 2.5(w_1-2)^2\n",
    "\\label{eq:loss} \\tag{1}\n",
    "\\end{equation}\n",
    "$$\n",
    "\n",
    "The L1-norm regularizer (aka lasso regression; lasso: least absolute shrinkage and selection operator):"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div id=\"eq:l1\"></div>\n",
    "\n",
    "$$\n",
    "\\begin{equation}\n",
    "L_1(W) = \\sum_{i=1}^{|W|} |w_i|\n",
    "\\label{eq:eq1} \\tag{2}\n",
    "\\end{equation}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The L2 regularizer (aka Ridge Regression and Tikhonov regularization), is not is the square of the L2-norm, and this little nuance is sometimes overlooked."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div id=\"eq:l2\"></div>\n",
    "\n",
    "$$\n",
    "\\begin{equation}\n",
    "L_2(W) = ||W||_2^2= \\sum_{i=1}^{|W|} w_i^2\n",
    "\\label{eq:eq3} \\tag{3}\n",
    "\\end{equation}\n",
    "$$"
   ]
  },
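  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make this nuance concrete, here is a minimal, dependency-free sketch that evaluates the $L_1$ regularizer, the plain $L_2$-norm and the $L_2$ regularizer (the squared norm) on one weight vector. The helper names `l1_penalty`, `l2_norm` and `l2_penalty` and the example weights are hypothetical; they are not used elsewhere in this notebook.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def l1_penalty(weights):\n",
    "    # Sum of absolute values, as in equation (2)\n",
    "    return sum(abs(w) for w in weights)\n",
    "\n",
    "def l2_norm(weights):\n",
    "    # Euclidean length of the weight vector (this is NOT the L2 regularizer)\n",
    "    return math.sqrt(sum(w * w for w in weights))\n",
    "\n",
    "def l2_penalty(weights):\n",
    "    # Squared Euclidean length, as in equation (3)\n",
    "    return sum(w * w for w in weights)\n",
    "\n",
    "W_example = [3.0, 4.0]  # illustrative weights, chosen so the norm is a whole number\n",
    "print(l1_penalty(W_example))\n",
    "print(l2_norm(W_example))\n",
    "print(l2_penalty(W_example))\n",
    "```\n",
    "\n",
    "For these example weights the three quantities are 7, 5 and 25, which shows how quickly the squared norm grows compared with the norm itself."
   ]
  },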
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(W):\n",
    "    return 0.5*(W[0]-3)**2 + 2.5*(W[1]-2)**2\n",
    "\n",
    "# L1 regularization\n",
    "def L1_regularization(W):\n",
    "    return abs(W[0]) + abs(W[1])\n",
    "\n",
    "def L2_regularization(W):\n",
    "    return np.sqrt(W[0]**2 + W[1]**2)\n",
    "\n",
    "fig = plt.figure(figsize=(10,10))\n",
    "ax = fig.gca(projection=\"3d\")\n",
    "\n",
    "xmesh, ymesh = np.mgrid[-3:9:50j,-2:6:50j]\n",
    "loss_mesh = loss_fn(np.array([xmesh, ymesh]))\n",
    "ax.plot_surface(xmesh, ymesh, loss_mesh);\n",
    "\n",
    "l1_mesh = L1_regularization(np.array([xmesh, ymesh]))\n",
    "ax.plot_surface(xmesh, ymesh, l1_mesh);\n",
    "\n",
    "l2_mesh = L2_regularization(np.array([xmesh, ymesh]))\n",
    "ax.plot_surface(xmesh, ymesh, l2_mesh);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot the training progress\n",
    "\n",
    "<br>\n",
    "The diamond contour lines are the values of the L1 regularization.  Since this is a contour diagram, all the points on a contour line have the same L1 value.  <br>\n",
    "In otherwords, for all points on a contour line:\n",
    "$$L_1(W) = \\left|w_0\\right| + \\left|w_1\\right| == constant$$ \n",
    "This is called the L1-ball.<br>\n",
    "L2-balls maintain the equation: $$L_2(W) = w_0^2 + w_1^2 == constant$$\n",
    "<br>\n",
    "The oval contour line are the values of the data loss function.  The regularized solution tries to find weights that satisfy both the data loss and the regularization loss. \n",
    "<br><br>\n",
    "```alpha``` and ```beta``` control the strengh of the regularlization loss versus the data loss.\n",
    "To see how the regularizers act \"in the wild\", set ```alpha``` and ```beta``` to a high value like 10.  The regularizers will then dominate the loss, and you will see how each of the regularizers acts.\n",
    "<br>\n",
    "Experiment with the value of alpha to see how it works."
   ]
  },
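  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough guide to where the regularized solutions should sit on the contour plot, the minima can be worked out by hand. This is only a sketch under two assumptions that are not stated above: a single regularizer is active at a time, and in the $L_1$ case both weights stay positive (which holds here for $\\alpha < 3$). The symbols $\\alpha$ and $\\beta$ stand for ```alpha``` and ```beta```.\n",
    "\n",
    "With $L_2$ regularization only, setting the gradient of $loss(W) + \\beta L_2(W)$ to zero gives:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "(w_0 - 3) + 2\\beta w_0 = 0 \\;&\\Rightarrow\\; w_0 = \\frac{3}{1 + 2\\beta} \\\\\n",
    "5(w_1 - 2) + 2\\beta w_1 = 0 \\;&\\Rightarrow\\; w_1 = \\frac{10}{5 + 2\\beta}\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "With $L_1$ regularization only, and $w_0, w_1 > 0$, setting the gradient of $loss(W) + \\alpha L_1(W)$ to zero gives:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "(w_0 - 3) + \\alpha = 0 \\;&\\Rightarrow\\; w_0 = 3 - \\alpha \\\\\n",
    "5(w_1 - 2) + \\alpha = 0 \\;&\\Rightarrow\\; w_1 = 2 - \\frac{\\alpha}{5}\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "For example, with $\\alpha = \\beta = 0.75$ the $L_1$ solution is $(2.25, 1.85)$ and the $L_2$ solution is $(1.2, \\approx 1.54)$. Note that $L_2$ scales each weight toward the origin, while $L_1$ subtracts a fixed amount that depends on the curvature of the loss along that axis."
   ]
  },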
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "initial_guess = torch.Tensor([8,5]) \n",
    "W = Variable(initial_guess, requires_grad=True)\n",
    "W_l1_reg = Variable(initial_guess.clone(), requires_grad=True) \n",
    "W_l2_reg = Variable(initial_guess.clone(), requires_grad=True) \n",
    "\n",
    "def L1_regularization(W):\n",
    "    return W.norm(1)\n",
    "\n",
    "def L2_regularization(W):\n",
    "    return W.pow(2).sum()\n",
    "\n",
    "lr = 0.04\n",
    "alpha = 0.75 # 1.5 # 4 # 0.4\n",
    "beta = 0.75\n",
    "num_steps = 1000\n",
    "\n",
    "def train(W, lr, alpha, beta, num_steps):\n",
    "    guesses = []\n",
    "    for i in range(num_steps):\n",
    "        # Zero the gradients of the weights\n",
    "        if W.grad is not None:\n",
    "            W.grad.data.zero_()\n",
    "\n",
    "        # Compute the loss and the gradients of W\n",
    "        loss = loss_fn(W) + alpha * L1_regularization(W) + beta * L2_regularization(W)\n",
    "        loss.backward()\n",
    "\n",
    "        # Update W\n",
    "        W.data = W.data - lr * W.grad.data\n",
    "        guesses.append(W.data.numpy())\n",
    "    return guesses\n",
    "\n",
    "# Train the weights without regularization\n",
    "guesses = train(W, lr, alpha=0, beta=0, num_steps=num_steps)\n",
    "# ...and with L1 regularization\n",
    "guesses_l1_reg = train(W_l1_reg, lr, alpha=alpha, beta=0, num_steps=num_steps)\n",
    "guesses_l2_reg = train(W_l2_reg, lr, alpha=0,     beta=beta, num_steps=num_steps)\n",
    "\n",
    "\n",
    "fig = plt.figure(figsize=(15,10))\n",
    "plt.axis(\"equal\")\n",
    "\n",
    "# Draw the contour maps of the data-loss and regularization loss\n",
    "CS = plt.contour(xmesh, ymesh, loss_mesh, 10, cmap=plt.cm.bone)\n",
    "# Draw the L1-balls\n",
    "CS2 = plt.contour(xmesh, ymesh, l1_mesh, 10, linestyles='dashed', levels=list(range(5)));\n",
    "# Draw the L2-balls\n",
    "CS3 = plt.contour(xmesh, ymesh, l2_mesh, 10, linestyles='dashed', levels=list(range(5)));\n",
    "\n",
    "# Add green contour lines near the loss minimum\n",
    "CS4 = plt.contour(CS, levels=[0.25, 0.5], colors='g')\n",
    "\n",
    "# Place a green dot at the data loss minimum, and an orange dot at the origin\n",
    "plt.scatter(3,2, color='g')\n",
    "plt.scatter(0,0, color='r')\n",
    "\n",
    "# Color bars and labels\n",
    "plt.xlabel(\"W[0]\")\n",
    "plt.ylabel(\"W[1]\")\n",
    "CB  = plt.colorbar(CS, label=\"data loss\", shrink=0.8, extend='both')\n",
    "CB2 = plt.colorbar(CS2, label=\"reg loss\", shrink=0.8, extend='both')\n",
    "# Label the contour lines\n",
    "plt.clabel(CS, fmt = '%2d', colors = 'k', fontsize=14) #contour line labels\n",
    "plt.clabel(CS2, fmt = '%2d', colors = 'red', fontsize=14) #contour line labels\n",
    "\n",
    "# Plot the two sets of weights (green are weights w/o regularization; red are L1; blue are L2)\n",
    "it_array = np.array(guesses)\n",
    "unregularized = plt.plot(it_array.T[0], it_array.T[1], \"o\", color='g')\n",
    "it_array = np.array(guesses_l1_reg)\n",
    "l1 = plt.plot(it_array.T[0], it_array.T[1], \"+\", color='r')\n",
    "it_array = np.array(guesses_l2_reg)\n",
    "l2 = plt.plot(it_array.T[0], it_array.T[1], \"+\", color='b')\n",
    "\n",
    "# Legends require a proxy artists in this case\n",
    "unregularized = mpatches.Patch(color='g', label='unregularized')\n",
    "l1 = mpatches.Patch(color='r', label='L1')\n",
    "l2 = mpatches.Patch(color='b', label='L2')\n",
    "plt.legend(handles=[unregularized, l1, l2])\n",
    "\n",
    "# Finally add the axes, so we can see how far we are from the sparse solution.\n",
    "plt.axhline(0, color='orange')\n",
    "plt.axvline(0, color='orange')\n",
    "\n",
    "print(\"solution: loss(%.3f, %.3f)=%.3f\" % (W.data[0], W.data[1], loss_fn(W)))\n",
    "print(\"solution: l1_loss(%.3f, %.3f)=%.3f\" % (W.data[0], W.data[1], L1_regularization(W)))\n",
    "print(\"regularized solution: loss(%.3f, %.3f)=%.3f\" % (W_l1_reg.data[0], W_l1_reg.data[1], loss_fn(W_l1_reg)))\n",
    "print(\"regularized solution: l1_loss(%.3f, %.3f)=%.3f\" % (W_l1_reg.data[0], W_l1_reg.data[1], L1_regularization(W_l1_reg)))\n",
    "print(\"regularized solution: l2_loss(%.3f, %.3f)=%.3f\" % (W_l2_reg.data[0], W_l2_reg.data[1], L2_regularization(W_l2_reg)))"
   ]
  },
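  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a cross-check on the trajectories plotted above, the same descent can be sketched without autograd, using gradients derived by hand from equations (1) to (3). The function name `descend` and its defaults are hypothetical; the defaults simply mirror the learning rate, step count and starting point used in this notebook.\n",
    "\n",
    "```python\n",
    "def descend(alpha, beta, lr=0.04, num_steps=1000, start=(8.0, 5.0)):\n",
    "    # Gradient descent on loss(W) + alpha*L1(W) + beta*L2(W),\n",
    "    # with hand-derived gradients instead of autograd.\n",
    "    w0, w1 = start\n",
    "    for _ in range(num_steps):\n",
    "        # sign(w) is the (sub)gradient of |w|; it evaluates to 0 at w == 0\n",
    "        sign0 = (w0 > 0) - (w0 < 0)\n",
    "        sign1 = (w1 > 0) - (w1 < 0)\n",
    "        g0 = (w0 - 3.0) + alpha * sign0 + 2.0 * beta * w0\n",
    "        g1 = 5.0 * (w1 - 2.0) + alpha * sign1 + 2.0 * beta * w1\n",
    "        w0, w1 = w0 - lr * g0, w1 - lr * g1\n",
    "    return w0, w1\n",
    "\n",
    "print('unregularized:', descend(alpha=0.0, beta=0.0))\n",
    "print('L1 only:      ', descend(alpha=0.75, beta=0.0))\n",
    "print('L2 only:      ', descend(alpha=0.0, beta=0.75))\n",
    "```\n",
    "\n",
    "The unregularized run should settle at the data loss minimum $(3, 2)$, while both regularized runs should stop short of it, closer to the origin."
   ]
  },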
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## $L_1$-norm regularization leads to \"near-sparsity\"\n",
    "\n",
    "$L_1$-norm regularization is often touted as sparsity inducing, but it actually creates solutions that oscillate around 0, not exactly 0 as we'd like. <br>\n",
    "To demonstrate this, we redefine our toy loss function so that the optimal solution for $w_0$ is close to 0 (0.3)."
   ]
  },
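  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the exact minimizer for $w_0$ can be worked out by hand. This is a sketch that treats $w_0$ on its own (reasonable here, since the loss has no cross terms between $w_0$ and $w_1$) and writes $\\alpha$ for the ```alpha``` coefficient of the $L_1$ term:\n",
    "\n",
    "$$\n",
    "\\min_{w_0} \\; 0.5(w_0 - 0.3)^2 + \\alpha \\left|w_0\\right|\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "w_0 > 0: \\quad (w_0 - 0.3) + \\alpha = 0 \\;&\\Rightarrow\\; w_0 = 0.3 - \\alpha \\quad \\text{(positive only if } \\alpha < 0.3\\text{)} \\\\\n",
    "w_0 < 0: \\quad (w_0 - 0.3) - \\alpha = 0 \\;&\\Rightarrow\\; w_0 = 0.3 + \\alpha \\quad \\text{(never negative for } \\alpha \\ge 0\\text{)}\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "So for any $\\alpha \\ge 0.3$ neither branch has a stationary point, and the minimum sits at the kink, $w_0 = 0$. With a value such as $\\alpha = 0.75$, the sparse solution $w_0 = 0$ is therefore the exact answer that the training would ideally reach."
   ]
  },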
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(W):\n",
    "    return 0.5*(W[0]-0.3)**2 + 2.5*(W[1]-2)**2\n",
    "\n",
    "# Train again\n",
    "W = Variable(initial_guess, requires_grad=True)\n",
    "guesses_l1_reg = train(W, lr, alpha=alpha, beta=0, num_steps=num_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When we draw the progress of the weight training, we see that $W_0$ is gravitating towards zero.<br>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the contour maps of the data-loss and regularization loss\n",
    "CS = plt.contour(xmesh, ymesh, loss_mesh, 10, cmap=plt.cm.bone)\n",
    "\n",
    "# Plot the progress of the training process\n",
    "it_array = np.array(guesses_l1_reg)\n",
    "l1 = plt.plot(it_array.T[0], it_array.T[1], \"+\", color='r')\n",
    "\n",
    "# Finally add the axes, so we can see how far we are from the sparse solution.\n",
    "plt.axhline(0, color='orange')\n",
    "plt.axvline(0, color='orange');\n",
    "plt.xlabel(\"W[0]\")\n",
    "plt.ylabel(\"W[1]\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But if we look closer at what happens to $w_0$ in the last 100 steps of the training, we see that is oscillates around 0, but never quite lands there.  Why?<br>\n",
    "Well, $dL1/dw_0$ is a constant (```lr * alpha``` in our case), so the weight update step:<br>\n",
    "```W.data = W.data - lr * W.grad.data``` <br>\n",
    "can be expanded to <br>\n",
    "```W.data = W.data - lr * (alpha + dloss_fn(W)/dW0)``` where ```dloss_fn(W)/dW0)``` <br>is the gradient of loss_fn(W) with respect to $w_0$. <br>\n",
    "The oscillations are not constant (although they do have a rythm) because they are influenced by this latter loss."
   ]
  },
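  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The oscillation described above can be reproduced in isolation with a one-dimensional sketch that needs no autograd. The function name `simulate_w0` and its defaults are hypothetical; the defaults are assumed to match this notebook (```alpha = 0.75```, ```lr = 0.04```, 1000 steps, a start at 8, and a data-loss minimum at 0.3).\n",
    "\n",
    "```python\n",
    "def simulate_w0(alpha=0.75, lr=0.04, num_steps=1000, start=8.0, target=0.3):\n",
    "    # 1-D (sub)gradient descent on 0.5*(w0 - target)**2 + alpha*|w0|\n",
    "    w0 = start\n",
    "    history = []\n",
    "    for _ in range(num_steps):\n",
    "        sign = (w0 > 0) - (w0 < 0)           # subgradient of |w0|\n",
    "        grad = (w0 - target) + alpha * sign  # data-loss term plus L1 term\n",
    "        w0 -= lr * grad\n",
    "        history.append(w0)\n",
    "    return history\n",
    "\n",
    "tail = simulate_w0()[-100:]\n",
    "print('min of last 100 steps:', min(tail))\n",
    "print('max of last 100 steps:', max(tail))\n",
    "print('exact zeros in last 100 steps:', sum(1 for w in tail if w == 0.0))\n",
    "```\n",
    "\n",
    "Under these assumptions the last 100 values keep changing sign: the minimum is slightly below zero and the maximum slightly above it, both within a few hundredths of zero."
   ]
  },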
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "it_array = np.array(guesses_l1_reg[int(0.9*num_steps):])\n",
    "for i in range(len(it_array)):\n",
    "    print(\"%.4f\\t(diff=%.4f)\" % (it_array.T[0][i], abs(it_array.T[0][i]-it_array.T[0][i-1])))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "<div id=\"Goodfellow-et-al-2016\"></div> **Ian Goodfellow and Yoshua Bengio and Aaron Courville**. \n",
    "    [*Deep Learning*](http://www.deeplearningbook.org),\n",
    "    MIT Press,\n",
    "    2016."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
